import sys
import time
import signal
import argparse
import os
from pathlib import Path

import numpy as np
import torch
import visdom
import data
from magic import MAGIC
from magicLRR import MAGICLRR
from magicLRR_3layers import MAGICLRR_3LAYERS
from magicSubschedulerLRR import MAGICSubschedulerLRR
from magicTNNLRR import MAGICTNNLRR
from stnnr import STNNR
from utils import *
from action_utils import parse_action_args
from trainer import Trainer, TrainerLRR
from multi_processing import MultiProcessTrainer
import gym

from tensorboardX import SummaryWriter

gym.logger.set_level(40)

torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

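# Run all computation in double precision.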
torch.set_default_tensor_type('torch.DoubleTensor')

parser = argparse.ArgumentParser(description='Multi-Agent Graph Attention Communication')
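# Example invocation (illustrative only; the right flags depend on the chosen env and method):
#   python main.py --env_name grf --method magicLRR --nagents 3 --num_epochs 100 --save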

# training
parser.add_argument('--num_epochs', default=100, type=int,
                    help='number of training epochs')
parser.add_argument('--epoch_size', type=int, default=10,
                    help='number of update iterations in an epoch')
parser.add_argument('--batch_size', type=int, default=500,
                    help='number of steps before each update (per thread)')
parser.add_argument('--nprocesses', type=int, default=16,
                    help='number of processes to run')

# model
parser.add_argument('--hid_size', default=64, type=int,
                    help='hidden layer size')
parser.add_argument('--directed', action='store_true', default=False,
                    help='whether the communication graph is directed')
parser.add_argument('--self_loop_type1', default=2, type=int,
                    help='self loop type in the first gat layer (0: no self loop, 1: with self loop, 2: decided by hard attn mechanism)')
parser.add_argument('--self_loop_type2', default=2, type=int,
                    help='self loop type in the second gat layer (0: no self loop, 1: with self loop, 2: decided by hard attn mechanism)')
parser.add_argument('--self_loop_type3', default=2, type=int,
                    help='self loop type in the third gat layer (0: no self loop, 1: with self loop, 2: decided by hard attn mechanism)')
parser.add_argument('--gat_num_heads', default=1, type=int,
                    help='number of heads in gat layers except the last one')
parser.add_argument('--gat_num_regulated_heads', default=1, type=int,
                    help='number of nuclear norm regulated heads in gat layers except the last one')
parser.add_argument('--gat_num_heads_out', default=1, type=int,
                    help='number of heads in output gat layer')
parser.add_argument('--gat_num_regulated_heads_out', default=1, type=int,
                    help='number of nuclear norm regulated heads in output gat layer')
parser.add_argument('--gat_hid_size', default=64, type=int,
                    help='hidden size of one head in gat')
parser.add_argument('--ge_num_heads', default=4, type=int,
                    help='number of heads in the gat encoder')
parser.add_argument('--first_gat_normalize', action='store_true', default=False,
                    help='whether to normalize the coefficients in the first gat layer of the message processor')
parser.add_argument('--second_gat_normalize', action='store_true', default=False,
                    help='whether to normalize the coefficients in the second gat layer of the message processor')
parser.add_argument('--third_gat_normalize', action='store_true', default=False,
                    help='whether to normalize the coefficients in the third gat layer of the message processor')
parser.add_argument('--gat_encoder_normalize', action='store_true', default=False,
                    help='whether to normalize the coefficients in the gat encoder (they are already normalized if the input graph is complete)')
parser.add_argument('--use_gat_encoder', action='store_true', default=False,
                    help='whether to use the gat encoder before learning the first graph')
parser.add_argument('--use_gat_v2_SharedW_encoder', action='store_true', default=False,
                    help='whether to use the gat_v2_SharedW encoder before learning the first graph')
parser.add_argument('--gat_encoder_out_size', default=64, type=int,
                    help='hidden size of output of the gat encoder')
parser.add_argument('--first_graph_complete', action='store_true', default=False,
                    help='whether the first communication graph is set to a complete graph')
parser.add_argument('--second_graph_complete', action='store_true', default=False,
                    help='whether the second communication graph is set to a complete graph')
parser.add_argument('--third_graph_complete', action='store_true', default=False,
                    help='whether the third communication graph is set to a complete graph')
parser.add_argument('--learn_second_graph', action='store_true', default=False,
                    help='whether to learn a new communication graph at the second round of communication')
parser.add_argument('--learn_third_graph', action='store_true', default=False,
                    help='whether to learn a new communication graph at the third round of communication')
parser.add_argument('--message_encoder', action='store_true', default=False,
                    help='whether to use the message encoder')
parser.add_argument('--message_decoder', action='store_true', default=False,
                    help='whether to use the message decoder')
parser.add_argument('--nagents', type=int, default=1,
                    help="number of agents")
parser.add_argument('--mean_ratio', default=0, type=float,
                    help='how cooperative the agents are (1.0 means fully cooperative)')
parser.add_argument('--detach_gap', default=10000, type=int,
                    help='detach hidden state and cell state for rnns at this interval')
parser.add_argument('--comm_init', default='uniform', type=str,
                    help='how to initialise comm weights [uniform|zeros]')
parser.add_argument('--advantages_per_action', default=False, action='store_true',
                    help='whether to multiply the log prob of each chosen action by the advantages')
parser.add_argument('--comm_mask_zero', action='store_true', default=False,
                    help='whether to block the communication')

# optimization
parser.add_argument('--gamma', type=float, default=1.0,
                    help='discount factor')
parser.add_argument('--seed', type=int, default=-1,
                    help='random seed')
parser.add_argument('--normalize_rewards', action='store_true', default=False,
                    help='normalize rewards in each batch')
parser.add_argument('--lrate', type=float, default=0.001,
                    help='learning rate')
parser.add_argument('--entr', type=float, default=0,
                    help='entropy regularization coeff')
parser.add_argument('--value_coeff', type=float, default=0.01,
                    help='coefficient for value loss term')
parser.add_argument('--low_rank_coeff', type=float, default=0.1,
                    help='coefficient for rank loss term')
parser.add_argument('--high_rank_coeff1', type=float, default=0, # 0.005
                    help='coefficient for attention1 rank loss')
parser.add_argument('--high_rank_coeff2', type=float, default=0,
                    help='coefficient for attention2 rank loss')
parser.add_argument('--high_rank_coeff3', type=float, default=0,
                    help='coefficient for attention3 rank loss')
parser.add_argument('--value_rank_coeff', type=float, default=0.01,
                    help='coefficient for the value rank loss term')

# environment
parser.add_argument('--env_name', default="grf",
                    help='name of the environment to run')
parser.add_argument('--max_steps', default=20, type=int,
                    help='force to end the game after this many steps')
parser.add_argument('--nactions', default='1', type=str,
                    help='the number of agent actions')
parser.add_argument('--action_scale', default=1.0, type=float,
                    help='scale action output from model')

# other
parser.add_argument('--plot', action='store_true', default=False,
                    help='plot training progress')
parser.add_argument('--plot_env', default='main', type=str,
                    help='plot env name')
parser.add_argument('--plot_port', default='8097', type=str,
                    help='plot port')
parser.add_argument('--save', action="store_true", default=False,
                    help='save the model after training')
parser.add_argument('--save_every', default=0, type=int,
                    help='save the model after every n_th epoch')
parser.add_argument('--load', default='', type=str,
                    help='load the model')
parser.add_argument('--display', action="store_true", default=False,
                    help='display environment state')
parser.add_argument('--random', action='store_true', default=False,
                    help="enable random model")
parser.add_argument('--method', type=str,
                    choices=['magic', 'magicLRR', 'magicLRR_3layers', 'magicSubschedulerLRR', 'magicTNNLRR', 'STNNR'],
                    help='which method to use (magic, magicLRR, magicLRR_3layers, magicSubschedulerLRR, magicTNNLRR, or STNNR)')

parser.add_argument('--use_gat', action='store_true', default=False,
                    help="use vanilla GAT for static attention")
parser.add_argument('--use_gat_v2', action='store_true', default=False,
                    help="use improved GAT-v2 for dynamic attention")
parser.add_argument('--use_general_attention', action='store_true', default=False,
                    help="use general attention copied from DICG")
parser.add_argument('--use_gat_v2_SharedW', action='store_true', default=False,
                    help="use improved GAT-v2-shared-W for dynamic attention")
parser.add_argument('--gumbel_matrix_third_way', default=3, type=int,
                    help='dimension of the third way of gumbel matrix')
parser.add_argument('--fft_times', default=1, type=int,
                    help='times of the dimension after fft')
parser.add_argument('--TNN_mode1', default='', type=str,
                    help='set the TNN1 mode')
parser.add_argument('--TNN_mode2', default='', type=str,
                    help='set the TNN2 mode')

init_args_for_env(parser)
args = parser.parse_args()

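# Track friendly agents separately; if the env has communicating enemies, count them in nagents too.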
args.nfriendly = args.nagents
if hasattr(args, 'enemy_comm') and args.enemy_comm:
    if hasattr(args, 'nenemies'):
        args.nagents += args.nenemies
    else:
        raise RuntimeError("Env. needs to pass argument 'nenemies'.")

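# For GRF, turn rendering off while the training env is constructed; the flag is restored below.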
if args.env_name == 'grf':
    render = args.render
    args.render = False
env = data.init(args.env_name, args, False)

args.obs_size = env.observation_dim
args.num_actions = env.num_actions

# Multi-action
if not isinstance(args.num_actions, (list, tuple)):  # single action case
    args.num_actions = [args.num_actions]
args.dim_actions = env.dim_actions

parse_action_args(args)

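# Fall back to a random seed so the exact seed used is still recorded in the printed args.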
if args.seed == -1:
    args.seed = np.random.randint(0, 10000)
torch.manual_seed(args.seed)

print(args)

if args.method == 'magic':
    policy_net = MAGIC(args)
elif args.method == 'magicLRR':
    policy_net = MAGICLRR(args)
elif args.method == 'magicLRR_3layers':
    policy_net = MAGICLRR_3LAYERS(args)
elif args.method == 'magicSubschedulerLRR':
    policy_net = MAGICSubschedulerLRR(args)
elif args.method == 'magicTNNLRR':
    policy_net = MAGICTNNLRR(args)
elif args.method == 'STNNR':
    policy_net = STNNR(args)
else:
    raise NotImplementedError

if not args.display:
    display_models([policy_net])

# share parameters among threads, but not gradients
for p in policy_net.parameters():
    p.data.share_memory_()

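# A dedicated trainer instance used only for rendering episodes via disp().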
disp_trainer = Trainer(args, policy_net, data.init(args.env_name, args, False))
disp_trainer.display = True


def disp():
    disp_trainer.get_episode()


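# Restore the deferred GRF render flag, then build a single- or multi-process trainer;
# the LRR-style methods go through TrainerLRR, which handles their extra rank statistics.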
if args.env_name == 'grf':
    args.render = render
if args.nprocesses > 1:
    if args.method == 'magic':
        trainer = MultiProcessTrainer(args, lambda: Trainer(args, policy_net, data.init(args.env_name, args)))
    else:
        trainer = MultiProcessTrainer(args, lambda: TrainerLRR(args, policy_net, data.init(args.env_name, args)))
else:
    if args.method == 'magic':
        trainer = Trainer(args, policy_net, data.init(args.env_name, args))
    else:
        trainer = TrainerLRR(args, policy_net, data.init(args.env_name, args))

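# LogField(data, plot, x_axis, divide_by): entries with a divide_by are averaged by that stat before logging.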
log = dict()
log['epoch'] = LogField(list(), False, None, None)
log['reward'] = LogField(list(), True, 'epoch', 'num_episodes')
log['enemy_reward'] = LogField(list(), True, 'epoch', 'num_episodes')
log['success'] = LogField(list(), True, 'epoch', 'num_episodes')
log['steps_taken'] = LogField(list(), True, 'epoch', 'num_episodes')
log['add_rate'] = LogField(list(), True, 'epoch', 'num_episodes')
log['comm_action'] = LogField(list(), True, 'epoch', 'num_steps')
log['enemy_comm'] = LogField(list(), True, 'epoch', 'num_steps')
log['value_loss'] = LogField(list(), True, 'epoch', 'num_steps')
log['action_loss'] = LogField(list(), True, 'epoch', 'num_steps')
log['entropy'] = LogField(list(), True, 'epoch', 'num_steps')

log['rank_to_low'] = LogField(list(), True, 'epoch', 'num_steps')
log['rank_to_high1'] = LogField(list(), True, 'epoch', 'num_steps')
log['rank_to_high2'] = LogField(list(), True, 'epoch', 'num_steps')

if args.plot:
    vis = visdom.Visdom(env=args.plot_env, port=args.plot_port)

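# Pick the next free runN directory under ./saved/<env>[/<scenario>] for checkpoints and TensorBoard logs.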
model_dir = Path('./saved') / args.env_name
if args.env_name == 'grf':
    model_dir = model_dir / args.scenario
if not model_dir.exists():
    curr_run = 'run1'
    log_dir = model_dir / 'run1' / 'tensorboard'
else:
    exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in
                     model_dir.iterdir() if
                     str(folder.name).startswith('run')]
    if len(exst_run_nums) == 0:
        curr_run = 'run1'
        log_dir = model_dir / 'run1' / 'tensorboard'
    else:
        curr_run = 'run%i' % (max(exst_run_nums) + 1)
        log_dir = model_dir / curr_run / 'tensorboard'

run_dir = model_dir / curr_run


def run(num_epochs):
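    """Train for num_epochs epochs; each epoch runs args.epoch_size batches and logs aggregate statistics."""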
    num_episodes = 0
    if args.save:
        os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(str(log_dir))
    for ep in range(num_epochs):  # each epoch here plays the role of one rollout
        epoch_begin_time = time.time()
        stat = dict()
        for n in range(args.epoch_size):
            if n == args.epoch_size - 1 and args.display:
                trainer.display = True
            s = trainer.train_batch(ep)
            print('batch: ', n)
            merge_stat(s, stat)
            trainer.display = False

        epoch_time = time.time() - epoch_begin_time
        epoch = len(log['epoch'].data) + 1
        num_episodes += stat['num_episodes']
        for k, v in log.items():
            if k == 'epoch':
                v.data.append(epoch)
            else:
                if k in stat and v.divide_by is not None and stat[v.divide_by] > 0:
                    stat[k] = stat[k] / stat[v.divide_by]
                v.data.append(stat.get(k, 0))

        np.set_printoptions(precision=2)

        print('Epoch {}'.format(epoch))
        print('Episode: {}'.format(num_episodes))
        print('Reward: {}'.format(stat['reward']))
        print('Average Reward: {}'.format(np.average(stat['reward'])))
        print('Time: {:.2f}s'.format(epoch_time))
        print('-------------------------------')

        # tabular.record('Epoch', epoch)
        writer.add_scalar('Reward', np.average(stat['reward']), global_step=epoch)
        writer.add_scalar('Action loss', stat['action_loss'], global_step=epoch)
        writer.add_scalar('Value loss', stat['value_loss'], global_step=epoch)
        writer.add_scalar('Action entropy', stat['entropy'], global_step=epoch)
        if args.method in ('magicLRR', 'magicSubschedulerLRR', 'magicTNNLRR', 'STNNR'):
            # writer.add_scalar('rank_to_low', stat['rank_to_low'], global_step=epoch)
            writer.add_scalar('rank_to_high1', stat['rank_to_high1'], global_step=epoch)
            writer.add_scalar('rank_to_high2', stat['rank_to_high2'], global_step=epoch)
            writer.add_scalar('attention1_every_head_rank', average(stat['attention1_every_head_rank']), global_step=epoch)
            writer.add_scalar('attention1_all_head_rank', average(stat['attention1_all_head_rank']), global_step=epoch)
            writer.add_scalar('attention1_sorted_every_head_rank', average(stat['attention1_sorted_every_head_rank']), global_step=epoch)
            writer.add_scalar('attention2_every_head_rank', average(stat['attention2_every_head_rank']), global_step=epoch)
            writer.add_scalar('attention2_all_head_rank', average(stat['attention2_all_head_rank']), global_step=epoch)
            writer.add_scalar('attention2_sorted_every_head_rank', average(stat['attention2_sorted_every_head_rank']),
                              global_step=epoch)
            if args.method in ('magicSubschedulerLRR', 'STNNR'):
                writer.add_scalar('sub_scheduler_attention1_every_head_rank', average(stat['sub_scheduler_attention1_every_head_rank']),
                                  global_step=epoch)
                # writer.add_scalar('sub_scheduler_attention1_all_head_rank', average(stat['sub_scheduler_attention1_all_head_rank']),
                #                   global_step=epoch)
                # writer.add_scalar('sub_scheduler_attention1_sorted_every_head_rank',
                #                   average(stat['sub_scheduler_attention1_sorted_every_head_rank']), global_step=epoch)
        if args.method == 'magicLRR_3layers':
            writer.add_scalar('rank_to_high1', stat['rank_to_high1'], global_step=epoch)
            writer.add_scalar('rank_to_high2', stat['rank_to_high2'], global_step=epoch)
            writer.add_scalar('attention1_every_head_rank', average(stat['attention1_every_head_rank']),
                              global_step=epoch)
            writer.add_scalar('attention1_all_head_rank', average(stat['attention1_all_head_rank']), global_step=epoch)
            writer.add_scalar('attention1_sorted_every_head_rank', average(stat['attention1_sorted_every_head_rank']),
                              global_step=epoch)
            writer.add_scalar('attention2_every_head_rank', average(stat['attention2_every_head_rank']),
                              global_step=epoch)
            writer.add_scalar('attention2_all_head_rank', average(stat['attention2_all_head_rank']), global_step=epoch)
            writer.add_scalar('attention2_sorted_every_head_rank', average(stat['attention2_sorted_every_head_rank']),
                              global_step=epoch)
            writer.add_scalar('attention3_every_head_rank', average(stat['attention3_every_head_rank']),
                              global_step=epoch)
            writer.add_scalar('attention3_all_head_rank', average(stat['attention3_all_head_rank']), global_step=epoch)
            writer.add_scalar('attention3_sorted_every_head_rank', average(stat['attention3_sorted_every_head_rank']),
                              global_step=epoch)


        if 'enemy_reward' in stat.keys():
            average_enemy_reward = average(stat['enemy_reward'])
            print('Enemy-Reward: {}'.format(average_enemy_reward))
            writer.add_scalar('Enemy-Reward', average_enemy_reward, global_step=epoch)
        if 'add_rate' in stat.keys():
            print('Add-Rate: {:.2f}'.format(stat['add_rate']))
            writer.add_scalar('Add-Rate', stat['add_rate'], global_step=epoch)
        if 'success' in stat.keys():
            print('Success: {:.4f}'.format(stat['success']))
            writer.add_scalar('Success', stat['success'], global_step=epoch)
        if 'steps_taken' in stat.keys():
            print('Steps-Taken: {:.2f}'.format(stat['steps_taken']))
            writer.add_scalar('Steps-Taken', stat['steps_taken'], global_step=epoch)
        if 'comm_action' in stat.keys():
            print('Comm-Action: {}'.format(stat['comm_action']))
            writer.add_scalar('Comm-Action', stat['comm_action'], global_step=epoch)
        if 'enemy_comm' in stat.keys():
            print('Enemy-Comm: {}'.format(stat['enemy_comm']))
            writer.add_scalar('Enemy-Comm', stat['enemy_comm'], global_step=epoch)

        if args.plot:
            for k, v in log.items():
                if v.plot and len(v.data) > 0:
                    vis.line(np.asarray(v.data), np.asarray(log[v.x_axis].data[-len(v.data):]),
                             win=k, opts=dict(xlabel=v.x_axis, ylabel=k))

        if args.save_every and ep and args.save and (ep + 1) % args.save_every == 0:
            save(final=False, epoch=ep + 1)

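        # model.pt is rewritten at the end of every epoch so it always holds the latest weights.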
        if args.save:
            save(final=True)
    writer.close()


def average(values):
    """Arithmetic mean of a sequence of numbers."""
    return sum(values, 0.0) / len(values)


def save(final, epoch=0):
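    """Checkpoint the policy, log, and trainer state: model.pt for the final model, model_ep<N>.pt otherwise."""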
    d = dict()
    d['policy_net'] = policy_net.state_dict()
    d['log'] = log
    d['trainer'] = trainer.state_dict()
    if final:
        torch.save(d, run_dir / 'model.pt')
    else:
        torch.save(d, run_dir / ('model_ep%i.pt' % (epoch)))


def load(path):
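    """Restore policy, log, and trainer state from a checkpoint written by save()."""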
    d = torch.load(path)
    # log.clear()
    policy_net.load_state_dict(d['policy_net'])
    log.update(d['log'])
    trainer.load_state_dict(d['trainer'])
    print('Model loaded successfully')


def signal_handler(sig, frame):
    print('You pressed Ctrl+C! Exiting gracefully.')
    if args.display:
        env.end_display()
    sys.exit(0)


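# Exit cleanly on Ctrl+C, closing the display window if one is open.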
signal.signal(signal.SIGINT, signal_handler)

if args.load != '':
    load(args.load)

run(args.num_epochs)
if args.display:
    env.end_display()

if args.save:
    save(final=True)

if sys.flags.interactive == 0 and args.nprocesses > 1:
    trainer.quit()
    os._exit(0)
